package face7wap;

import face5wap.LossGradientPenalty;
import face6wap.MnistOneDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossWasserstein;
import util.ShowUtils;
import util.ShowUtilsHanf;

import java.io.File;
import java.util.Map;

/**
 * WGAN-GP on MNIST with adjusted input sizes.
 *
 * Differences between WGAN-GP and WGAN:
 * https://liuxiaofei.com.cn/blog/wgan-gp%E4%B8%8Ewgan%E7%9A%84%E5%8C%BA%E5%88%AB
 *
 * Reference implementation this port follows:
 * https://github.com/bojone/gan/blob/master/mnist_gangp.py
 */
public class Gan_7Dis_gp_loss {

	// Base learning rate shared by the critic and generator updaters.
	static double lr = 0.00005;
	// Path where the trained critic graph is periodically saved.
	static String model = "F:/face/gan.zip";
	public static void main(String[] args) throws Exception {
		// Ask ND4J to run its workspace GC every 15 seconds to bound off-heap memory
		// during the long training loop.
		Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
		// Shared updater configuration; Adam with beta1 = 0.5 (common GAN setting).
		IUpdater updater = Adam.builder().learningRate(lr).beta1(0.5).build(); // alternative: new RmsProp(lr)

		// Critic (discriminator) graph: 784 -> 256 -> 128 -> 128 -> 1,
		// output head uses a Wasserstein loss.
		final GraphBuilder dis = new NeuralNetConfiguration.Builder().updater(updater)
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("fake")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"fake")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(new Loss7Wasserstein()).nIn(128).nOut(1)
						.activation(Activation.RELU).build(), "d3")
				.setOutputs("out");


		// Combined generator+critic graph: noise -> g1..g3 (generator, 784-dim output)
		// followed by a copy of the critic layers d1..d3/out. The d* layers get their
		// weights synced from `net` each outer iteration and are frozen in trainG.
		final GraphBuilder gen = new NeuralNetConfiguration.Builder().updater(updater)
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("fake")
				.addLayer("g1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"fake")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g3")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(new Loss7Wasserstein()).nIn(128).nOut(1)
						.activation(Activation.RELU).build(), "d3")
				.setOutputs("out");

		ComputationGraph net = new ComputationGraph(dis.build());
		ComputationGraph net1 = new ComputationGraph(gen.build());

		net.init();
		net1.init();

		System.out.println(net.summary());
		System.out.println(net1.summary());

		net.setListeners(new ScoreIterationListener(100));
		net1.setListeners(new ScoreIterationListener(100));

		// MNIST training iterator, batch size 30, shuffled with a fixed seed.
		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
		// vstack stacks arrays vertically (row order) into a new array.
		//INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
		//INDArray labelG = Nd4j.ones(30 * 3, 1);
		/*INDArray real = Nd4j.vstack(Nd4j.ones(30 , 1).muli(-1));
		INDArray fake = Nd4j.vstack(Nd4j.ones(30, 1));
		INDArray dumpy = Nd4j.vstack(Nd4j.ones(30, 1));
		INDArray all = Nd4j.vstack(fake, real,dumpy);*/

		// WGAN targets: -1 for real samples, +1 for generated samples.
		INDArray valid = Nd4j.ones(30 , 1).muli(-1);
		INDArray fake =  Nd4j.ones(30 , 1);
		// Target for the gradient-penalty batch.
		// NOTE(review): zeros(...).muli(0.5) is still all zeros; the 0.5 factor has no effect.
		INDArray dummy = Nd4j.zeros(30, 1).muli(0.5);




		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}


			// Train the critic 10 times per generator step (WGAN-style n_critic).
			for(int m=0;m<10;m++){
				INDArray real_img = train.next().getFeatures();
				INDArray z = Nd4j.rand(new NormalDistribution(),new long[] { 30, 28*28 });
				INDArray fake_img =  net1.feedForward(new INDArray[] {z}, false).get("g3");// .reshape(20,28,28);

				// 1) Blend real and fake images with random interpolation weights.
				INDArray interpolated_img = randomWeightedAverage(30,real_img,fake_img);
				/*((OutputLayer)net.getLayer("out")).setLossFn(new Loss7GradientPenalty());*/
				// NOTE(review): this assignment is repeated just before trainD below;
				// one of the two setLossFn(Loss7Wasserstein) calls is redundant.
				((org.deeplearning4j.nn.layers.OutputLayer)net.getLayer("out")).layerConf().setLossFn(new Loss7Wasserstein());
				//INDArray validity_interpolated =  net.feedForward(new INDArray[] {interpolated_img}, false).get("out");// .reshape(20,28,28);
				// 2) Build critic batches for the real images, fake images and interpolated
				// images. Their loss functions are wasserstein_loss, wasserstein_loss and
				// partial_gp_loss; the targets are valid, fake and dummy respectively.
				MultiDataSet realD = new MultiDataSet(new INDArray[] {real_img},new INDArray[] { valid });
				MultiDataSet fakeD = new MultiDataSet(new INDArray[] {fake_img},new INDArray[] { fake });
				MultiDataSet interD = new MultiDataSet(new INDArray[] {interpolated_img},new INDArray[] { dummy });
				// 3) The gradient-penalty loss only needs the gradient of the prediction
				// y_pred w.r.t. the interpolated input averaged_samples. It computes the
				// Euclidean norm gradient_l2_norm of that gradient and penalizes its
				// distance from 1: the closer to 1, the smaller the penalty. This softly
				// enforces the 1-Lipschitz constraint (gradients neither too steep nor too flat).
			 	((org.deeplearning4j.nn.layers.OutputLayer)net.getLayer("out")).layerConf().setLossFn(new Loss7Wasserstein());
				trainD(net, realD);
				trainD(net, fakeD);
				((org.deeplearning4j.nn.layers.OutputLayer)net.getLayer("out")).layerConf().setLossFn(new Loss7GradientPenalty());
				trainD(net, interD);
			}

			// Copy the freshly trained critic weights into the (frozen) critic tail
			// of the combined generator graph before the generator update.
			net1.getLayer("d1").setParam("W", net.getLayer("d1").getParam("W"));
			net1.getLayer("d1").setParam("b", net.getLayer("d1").getParam("b"));
			net1.getLayer("d2").setParam("W", net.getLayer("d2").getParam("W"));
			net1.getLayer("d2").setParam("b", net.getLayer("d2").getParam("b"));
			net1.getLayer("d3").setParam("W", net.getLayer("d3").getParam("W"));
			net1.getLayer("d3").setParam("b", net.getLayer("d3").getParam("b"));
			net1.getLayer("out").setParam("W", net.getLayer("out").getParam("W"));
			net1.getLayer("out").setParam("b", net.getLayer("out").getParam("b"));

			// One generator step: fresh noise labeled as "real" (valid) so the generator
			// is pushed to fool the critic.
			INDArray z = Nd4j.rand(new NormalDistribution(),new long[] { 30, 28*28 });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] {z},
					new INDArray[] { valid });
			trainG(net1, dataSetG);

		/*	net.getLayer("g1").setParam("W", net1.getLayer("g1").getParam("W"));
			net.getLayer("g1").setParam("b", net1.getLayer("g1").getParam("b"));
			net.getLayer("g2").setParam("W", net1.getLayer("g2").getParam("W"));
			net.getLayer("g2").setParam("b", net1.getLayer("g2").getParam("b"));
			net.getLayer("g3").setParam("W", net1.getLayer("g3").getParam("W"));
			net.getLayer("g3").setParam("b", net1.getLayer("g3").getParam("b"));*/

			// Every 10 iterations: render a batch of generated samples.
			if (i % 10 == 0) {

				INDArray noise =  Nd4j.rand(new NormalDistribution(),new long[] { 10,28*28});
				// NOTE(review): noise1 is never used below.
				INDArray noise1 =  Nd4j.rand(new NormalDistribution(),new long[] {10,28*28});
				/*INDArray[] samps =  gen.output(noise);*/
				/*long[] shpaes = samps[0].shape();
				INDArray[] samples = new INDArray[(int)samps.length];
				for (int j = 0; j < samps.length; j++) {
					samples[j] = samps[j];
				}*/
				INDArray indArray2 = net1.feedForward(new INDArray[] {noise}, false).get("g3");// .reshape(20,28,28);
				// NOTE(review): samples is sized to size(0) = 10 slots but only index 0
				// is filled; indices 1..9 stay null — confirm ShowUtilsHanf.visualize
				// tolerates null entries, otherwise this should be a single-element array.
				INDArray[] samples = new INDArray[(int)indArray2.size(0)];

				samples[0] = indArray2;

				ShowUtilsHanf.visualize(samples,"拆分");
			}
			// Every 10000 iterations: checkpoint the critic to disk.
			if (i % 10000 == 0) {
			   net.save(new File(model), true);
			}
 
		}
 
	}
 	// 判别模型  D(x)
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}
	/**
	 * Runs one generator g(z) fit step on the combined graph: generator layers
	 * learn at the base rate while the critic tail is frozen (learning rate 0),
	 * so only the generator weights move.
	 *
	 * @param net     combined generator+critic graph (g1..g3 then d1..d3, out)
	 * @param dataSet one batch of noise inputs with "valid" targets
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		for (String trainable : new String[] { "g1", "g2", "g3" }) {
			net.setLearningRate(trainable, lr);
		}
		for (String frozen : new String[] { "d1", "d2", "d3", "out" }) {
			net.setLearningRate(frozen, 0);
		}
		net.fit(dataSet);
	}
	/**
	 * Returns a per-pixel random convex combination of a real and a fake batch,
	 * x_hat = alpha * real + (1 - alpha) * fake, used as the gradient-penalty
	 * sample in WGAN-GP.
	 *
	 * @param batch number of rows (batch size) for the interpolation weights
	 * @param real  batch of real images, shape {batch, 784}; not modified
	 * @param fake  batch of generated images, shape {batch, 784}; not modified
	 * @return a new INDArray holding the element-wise weighted average
	 */
	public static INDArray randomWeightedAverage(long batch, INDArray real, INDArray fake) {
		// WGAN-GP samples interpolation weights uniformly in [0, 1]; the previous
		// NormalDistribution(0, 1) produced weights outside that range (including
		// negative ones), yielding extrapolations rather than interpolations.
		INDArray alpha = Nd4j.rand(new UniformDistribution(0, 1), new long[] { batch, 784 });
		// Bug fix: the original chained alpha.muli(real), which overwrote alpha in
		// place, so the subsequent (1 - alpha) term actually computed 1 - alpha*real
		// and the result was alpha*real + (1 - alpha*real)*fake. Use non-destructive
		// mul/rsub so neither alpha nor the input batches are clobbered.
		return real.mul(alpha).addi(fake.mul(alpha.rsub(1.0)));
	}
}